# --------------------------------------------------------
# MTLoRA
# GitHub: https://github.com/scale-lab/MTLoRA
# Built upon Swin Transformer (https://github.com/microsoft/Swin-Transformer)
#
# Original file:
# Copyright (c) 2021 Microsoft
# Licensed under the MIT License
# Written by Ze Liu
#
# Modifications:
# Copyright (c) 2024 SCALE Lab, Brown University
# Licensed under the MIT License (see LICENSE for details)
# --------------------------------------------------------

import argparse
import datetime
import json
import os
import random
import time

import numpy as np
import pynvml
import torch
import torch.backends.cudnn as cudnn
from ptflops import get_model_complexity_info
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import AverageMeter

from config import get_config
from data import build_loader
from evaluation.evaluate_utils import PerformanceMeter, get_output
from logger import create_logger
from lr_scheduler import build_scheduler
from models import build_model, build_mtl_model
from models.lora import mark_only_lora_as_trainable
from mtl_loss_schemes import MultiTaskLoss, get_loss
from optimizer import build_optimizer
from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper

# import torch.distributed as dist

try:
    import wandb
except ImportError:
    # wandb is an optional dependency; degrade to console-only logging.
    print("Warning: wandb library not found. Logging is disabled.")
    wandb_available = False
else:
    wandb_available = True


def parse_option():
    """Parse command-line arguments and build the experiment config.

    Returns:
        tuple: ``(args, config)`` — the raw :mod:`argparse` namespace and the
        config object produced by ``get_config(args)``.
    """
    parser = argparse.ArgumentParser(
        'Swin Transformer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True,
                        metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int,
                        help="batch size for single GPU")
    parser.add_argument('--ckpt-freq', type=int, default=5,
                        help="checkpoint saving frequency")
    parser.add_argument('--eval-freq', type=int, default=5,
                        help="model evaluation frequency")
    parser.add_argument('--epochs', type=int, default=300,
                        help="number of epochs to train")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true',
                        help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int,
                        help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true',
                        help='Disable pytorch amp')
    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--name', type=str, help='override model name')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true',
                        help='Test throughput only')
    # distributed training
    parser.add_argument("--local_rank", type=int, default=0,
                        help='local rank for DistributedDataParallel')

    # for acceleration
    parser.add_argument('--fused_window_process', action='store_true',
                        help='Fused window shift & window partition, similar for reversed part.')
    parser.add_argument('--fused_layernorm',
                        action='store_true', help='Use fused layernorm.')
    # overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb
    parser.add_argument('--optim', type=str,
                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')

    # MTL Config
    parser.add_argument('--tasks', type=str, default='depth',
                        help='List of tasks to run in MTL setup.')
    parser.add_argument(
        '--nyud', type=str, help='specify the path to load NYUD, replaces --data-path')
    parser.add_argument(
        '--pascal', type=str, help='specify the path to load PASCAL, replaces --data-path and --nyud')
    parser.add_argument('--eval-training-freq', type=int,
                        help='calculate performance score on the training dataset')
    parser.add_argument('--resume-backbone',
                        help='resume checkpoint into the backbone')
    parser.add_argument('--freeze-backbone',
                        action='store_true', help='Freeze encoder layers.')

    parser.add_argument('--skip_initial_validation', action='store_true',
                        help='Skip running validation at the start')
    parser.add_argument('--decoder_map', type=str,
                        help='Path to JSON file containing the type of decoder heads')
    parser.add_argument('--skip_decoder', action='store_true',
                        help='Skip loading decoder head weights')
    parser.add_argument('--disable_wandb', action='store_true',
                        help='Disable wandb logging.')
    parser.add_argument('--run_name', type=str,
                        help='wandb run name')
    # NOTE: store_false means args.no_eval_50 defaults to True and the flag
    # sets it to False; the training loop gates the epoch-50 eval on
    # `not args.no_eval_50`, so passing this flag ENABLES that extra eval.
    # The original help text ("Disable the initial eval at 50 epochs")
    # described the opposite of the actual effect and is corrected here.
    parser.add_argument('--no_eval_50', action='store_false',
                        help='Enable an extra evaluation at epoch 50 '
                             '(skipped by default; flag name is historical).')

    args = parser.parse_args()

    config = get_config(args)

    return args, config


def main(config):
    """End-to-end entry point: build data loaders, model, optimizer and LR
    scheduler, optionally resume / evaluate / measure throughput, then run
    the training loop.

    Relies on the module-level ``logger`` and (for the optional epoch-50
    eval gate) the module-level ``args`` created in the ``__main__`` block.
    """
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    teacher = None
    # Build the backbone from the config; wrap it with task heads for MTL.
    model = build_model(config)
    if config.MTL:
        model = build_mtl_model(model, config)

    n_parameters = sum(p.numel() for p in model.parameters())
    logger.info(f"number of params: {n_parameters / 1e6} M")

    model.cuda()
    # Model complexity report:
    #   macs   - total multiply-accumulate operations
    #   params - total parameter count (as reported by ptflops)
    macs, params = get_model_complexity_info(model, (3, config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), as_strings=True,
                                             print_per_layer_stat=False, verbose=False)

    logger.info(f"ptflops GMACS = {macs} and params = {params}")

    model_without_ddp = model

    optimizer = build_optimizer(config, model)

    # Handles AMP gradient scaling (and gradient-norm bookkeeping).
    loss_scaler = NativeScalerWithGradNormCount()

    # The scheduler counts optimizer updates, so divide the per-epoch step
    # count by the accumulation steps when accumulating gradients (default 1).
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        lr_scheduler = build_scheduler(config, optimizer, len(
            data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
    else:
        lr_scheduler = build_scheduler(
            config, optimizer, len(data_loader_train))

    # Single-task (classification) criterion; overridden below for MTL.
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if config.MTL:
        # One loss function per task.
        loss_ft = torch.nn.ModuleDict(
            {task: get_loss(config['TASKS_CONFIG'], task, config) for task in config.TASKS})
        # Fixed per-task loss weights balancing each task's contribution to
        # the total loss — e.g. the edge loss is intrinsically small, so it
        # is scaled up by 50 to match the influence of the other tasks.
        all_loss_weights = {
            'depth': 1.0,
            'semseg': 1.0,
            'human_parts': 2.0,
            'sal': 5.0,
            'edge': 50.0,
            'normals': 10.0,
        }

        loss_weights = {}
        for t in config.TASKS:
            loss_weights[t] = all_loss_weights[t]

        # Combined multi-task loss.
        criterion = MultiTaskLoss(config.TASKS, loss_ft, loss_weights)

    max_accuracy = 0.0

    # Optionally pick up the latest checkpoint from the output directory.
    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    # Resume the full training state (model/optimizer/scheduler/scaler).
    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(
            config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)

        if not config.SKIP_INITIAL_EVAL:
            validate(config, data_loader_val, model, 0)
        if config.EVAL_MODE:
            return

    # Resume a checkpoint into the backbone only (e.g. fully fine-tuned weights).
    if config.MODEL.RESUME_BACKBONE:
        max_accuracy = load_checkpoint(
            config, model_without_ddp.backbone, optimizer, lr_scheduler, loss_scaler, logger, True)
        if config.EVAL_MODE:
            validate(config, data_loader_val, model, 0)
            return

    if config.EVAL_MODE:
        validate(config, data_loader_val, model, 0)
        return

    if teacher is not None:
        print("loading teacher.......")
        load_checkpoint(config, teacher, optimizer, lr_scheduler,
                        loss_scaler, logger, quiet=True)

    # Load pretrained weights (when not resuming) and optionally validate.
    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        if not config.SKIP_INITIAL_EVAL:
            # NOTE(review): validate() returns a single scores dict; this
            # 3-way unpack looks like a non-MTL legacy path — verify before
            # relying on it.
            acc1, _, _ = validate(config, data_loader_val, model, 0)

    # Throughput-only mode: time forward passes and exit.
    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    # Select the fine-tuning regime.
    if config.MODEL.MTLORA.ENABLED:
        # MTLoRA fine-tuning: freeze pretrained weights per the config.
        if config.MODEL.MTLORA.FREEZE_PRETRAINED:
            logger.info("\nMarking LoRA params only as trainable:")
            mark_only_lora_as_trainable(model.backbone,
                                        bias=config.MODEL.MTLORA.BIAS,
                                        freeze_patch_embed=config.TRAIN.FREEZE_PATCH_EMBED,
                                        freeze_norm=config.TRAIN.FREEZE_LAYER_NORM,
                                        free_relative_bias=config.TRAIN.FREEZE_RELATIVE_POSITION_BIAS,
                                        freeze_downsample_reduction=True if config.MODEL.MTLORA.DOWNSAMPLER_ENABLED else config.TRAIN.FREEZE_DOWNSAMPLE_REDUCTION)
        else:  # full fine-tuning
            logger.info("Marking all layers as trainable")
    # Decoder-only fine-tuning: freeze the whole backbone (mutually
    # exclusive with MTLoRA).
    if config.MODEL.FREEZE_BACKBONE:
        assert (not config.MODEL.MTLORA.ENABLED)
        print("Freezing backbone.........")
        model.freeze_backbone()
    # Parameter accounting (by substring match on parameter names).
    trainable_params = sum(p.numel()
                           for p in model.parameters() if p.requires_grad)
    lora_params = sum(p.numel() for name, p in model.named_parameters() if p.requires_grad and 'lora' in name)
    prefix_params = sum(p.numel() for name, p in model.named_parameters() if p.requires_grad and 'prefix' in name)
    adapter_params = sum(p.numel() for name, p in model.named_parameters() if p.requires_grad and 'adapter' in name)
    total_model_params = sum(p.numel() for p in model.parameters())
    total_model_params_without_lora = total_model_params - lora_params
    decoder_params = sum(p.numel() for name, p in model.named_parameters() if 'backbone' not in name)

    logger.info(f"""
Number of trainable params: {trainable_params:,}
Decoder params:             {decoder_params:,}
LoRA params:                {lora_params:,}
Prefix params:                {prefix_params:,}
Adapter params:                {adapter_params:,}
Extra params:                {(trainable_params - (lora_params + decoder_params + adapter_params + prefix_params)):,}
Total params:               {total_model_params:,} (trainable ratio: {trainable_params / total_model_params * 100:2.2f}%)
Total params without LoRA:  {total_model_params_without_lora:,} (trainable ratio: {trainable_params / total_model_params_without_lora * 100:2.2f}%)
""")
    logger.info("Start training")
    start_time = time.perf_counter()

    # Training loop.
    epoch = 0
    for epoch in range(config.TRAIN.EPOCHS):
        if not config.MTL:
            data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
                        loss_scaler, teacher=teacher,
                        )
        # Periodic checkpointing (and always at the last epoch).
        if epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                            logger)
        # Periodic validation; also reads the module-level `args` for the
        # optional extra eval at epoch 50 (enabled via --no_eval_50).
        if epoch % config.EVAL_FREQ == 0 or (not args.no_eval_50 and epoch == 50):
            if config.MTL:
                validate(config, data_loader_val, model, epoch)
            else:
                # NOTE(review): same 3-way unpack of validate()'s single
                # return value as above — verify for the non-MTL path.
                acc1, _, _ = validate(config, data_loader_val, model, epoch)
                max_accuracy = max(max_accuracy, acc1)

    # final eval
    validate(config, data_loader_val, model, epoch)
    total_time = time.perf_counter() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))


def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler,
                    task=None, teacher=None, dwa_coefficients=None):
    """Train ``model`` for one epoch over ``data_loader``.

    Supports both the single-task path (``config.MTL`` False: each batch is
    a ``(samples, targets)`` pair) and the MTL path (each batch is a dict
    with an ``'image'`` entry plus one entry per task).  Uses AMP autocast
    with ``loss_scaler`` and gradient accumulation
    (``config.TRAIN.ACCUMULATION_STEPS``); optionally logs to wandb.

    ``task``, ``teacher`` and ``dwa_coefficients`` are accepted for
    interface compatibility but are not used here.
    """
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()    # per-batch wall-clock time
    loss_meter = AverageMeter()    # training loss
    norm_meter = AverageMeter()    # gradient norm (updated only on optimizer steps)
    scaler_meter = AverageMeter()  # AMP loss-scale value

    performance_meter = PerformanceMeter(config, config.DATA.DBNAME)

    start = time.perf_counter()
    end = time.perf_counter()
    loss_dict = None   # per-task losses from the most recent batch
    outputs = None     # model outputs from the most recent batch
    targets = None

    for idx, batch in enumerate(data_loader):
        if not config.MTL:
            samples, targets = batch
            samples = samples.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
        else:
            samples = batch['image'].cuda(non_blocking=True)
            targets = {task: batch[task].cuda(
                non_blocking=True) for task in config.TASKS}

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)
            loss, loss_dict = criterion(outputs, targets)

        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        # loss_scaler backpropagates every batch but only steps the optimizer
        # on accumulation boundaries; it returns None when no update happened.
        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update(
                (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        if not config.MTL:
            loss_meter.update(loss.item(), targets.size(0))
        else:
            loss_meter.update(loss.item())

        if grad_norm is not None:  # loss_scaler returns None if no update happened
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.perf_counter() - end)
        end = time.perf_counter()
        if wandb_available:
            metrics = {
                "train/epoch_ndx": epoch,
                "train/batch_ndx": idx,
                "train/train_loss": loss_meter.val,
                "train/train_loss_avg": loss_meter.avg,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
                "train/weight_decay": optimizer.param_groups[0]['weight_decay'],
                "train/time": batch_time.val,
                "train/time_avg": batch_time.avg,
                "train/grad_norm": norm_meter.val,
                "train/grad_norm_avg": norm_meter.avg,
                "train/loss_scale": scaler_meter.val,
                "train/loss_scale_avg": scaler_meter.avg,
                "train/memory": torch.cuda.max_memory_allocated() / (1024.0 * 1024.0),
            }
            if loss_dict is not None:
                for task, task_loss in loss_dict.items():
                    metrics[f"train/tasks/{task}/loss"] = task_loss.item()
            wandb.log(metrics)

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            wd = optimizer.param_groups[0]['weight_decay']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.8f}\t wd {wd:.4f}\t'
                f'time {batch_time.val * 10 :.4f} ({batch_time.avg * 10:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

    if config.EVAL_TRAINING is not None and (epoch % config.EVAL_TRAINING == 0):
        # Cheap training-set sanity check: scores only the LAST batch of the
        # epoch (outputs/targets left over from the loop), not the full set.
        print("Training Eval:")
        if outputs is not None:  # guard: empty data_loader leaves no batch to score
            performance_meter.update(
                {t: get_output(outputs[t], t) for t in config.TASKS}, targets)

            scores = performance_meter.get_score(verbose=True)
            if wandb_available:
                scores_logs = {
                    "train/epoch": epoch,
                }
                # Membership is checked against `scores` — the dict actually
                # indexed below.  (Previously some checks used `loss_dict`,
                # which risked KeyErrors when the two disagreed.)
                if 'semseg' in scores:
                    scores_logs["train/tasks/semseg/mIoU"] = scores['semseg']['mIoU']
                if 'normals' in scores:
                    scores_logs["train/tasks/normals/mean"] = scores['normals']['mean']
                    scores_logs["train/tasks/normals/rmse"] = scores['normals']['rmse']
                    scores_logs["train/tasks/normals/mean_v2"] = scores['normals']['mean_v2']
                    scores_logs["train/tasks/normals/rmse_v2"] = scores['normals']['rmse_v2']
                if 'human_parts' in scores:
                    scores_logs["train/tasks/human_parts/mIoU"] = scores['human_parts']['mIoU']
                if 'sal' in scores:
                    scores_logs["train/tasks/sal/maxF"] = scores['sal']['maxF']
                    scores_logs["train/tasks/sal/Beta maxF"] = scores['sal']['Beta maxF']
                    scores_logs["train/tasks/sal/mIoU"] = scores['sal']['mIoU']
                if 'edge' in scores:
                    # Fixed: the edge loss was logged under "train/tasks/sal/loss".
                    scores_logs["train/tasks/edge/loss"] = scores['edge']['loss']
                if 'depth' in scores:
                    scores_logs["train/tasks/depth/rmse"] = scores['depth']['rmse']
                    scores_logs["train/tasks/depth/log_rmse"] = scores['depth']['log_rmse']

                wandb.log(scores_logs)

            logger.info("Training Eval:")
            logger.info(f"Epoch {epoch} Training Performance:")
            for task, metrics_dict in scores.items():
                logger.info(f"{task.upper()} Metrics:")
                for metric_name, value in metrics_dict.items():
                    logger.info(f"  {metric_name}: {value:.4f}")

    epoch_time = time.perf_counter() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


@torch.no_grad()
def validate(config, data_loader, model, epoch):
    """Evaluate the model in an online fashion without storing predictions.

    Builds a local MultiTaskLoss (weights kept in sync with ``main``),
    accumulates per-task performance via ``PerformanceMeter``, logs to
    wandb when available, and computes the multi-task delta-m metric
    against fixed single-task baselines.

    Returns:
        dict: per-task scores from ``performance_meter.get_score``.
    """
    tasks = config.TASKS
    performance_meter = PerformanceMeter(config, config.DATA.DBNAME)
    loss_meter = AverageMeter()

    loss_ft = torch.nn.ModuleDict(
        {task: get_loss(config['TASKS_CONFIG'], task, config) for task in config.TASKS})
    # Fixed per-task loss weights — must match the weights used in main().
    all_loss_weights = {
        'depth': 1.0,
        'semseg': 1.0,
        'human_parts': 2.0,
        'sal': 5.0,
        'edge': 50.0,
        'normals': 10.0,
    }
    loss_weights = {t: all_loss_weights[t] for t in config.TASKS}
    criterion = MultiTaskLoss(config.TASKS, loss_ft, loss_weights)

    model.eval()
    num_val_points = 0
    logger.info("Start eval")
    start = time.perf_counter()
    for i, batch in enumerate(data_loader):
        # Forward pass
        logger.debug(f"Image ID = {batch['meta']['image']}")
        images = batch['image'].cuda(non_blocking=True)
        targets = {task: batch[task].cuda(
            non_blocking=True) for task in tasks}

        output = model(images)

        num_val_points += 1

        # Measure loss and performance on this batch.
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            loss, loss_dict = criterion(
                output, targets)
            loss_meter.update(loss.item())
        processed_output = {t: get_output(
            output[t], t) for t in tasks}
        performance_meter.update(processed_output, targets)
        if wandb_available:
            metrics = {
                "val/epoch_ndx": epoch,
                "val/batch_ndx": i,
                "val/val_loss": loss_meter.val,
                "val/val_loss_avg": loss_meter.avg,
            }
            if loss_dict is not None:
                for task, task_loss in loss_dict.items():
                    metrics[f"val/tasks/{task}/loss"] = task_loss.item()
            wandb.log(metrics)

    logger.info(f"val loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t")

    eval_results = performance_meter.get_score(verbose=True)
    epoch_time = time.perf_counter() - start
    logger.info(
        f"eval takes {datetime.timedelta(seconds=int(epoch_time))}")

    scores_logs = {
        "val/epoch": epoch,
    }
    # Delta-m: average relative improvement over fixed single-task baselines
    # (semseg/normals/human_parts/sal only; edge and depth are excluded).
    delta_semseg, base_semseg = 0, 0.6721
    delta_norm, base_norm = 0, 17.97
    delta_human_parts, base_human_parts = 0, 0.6193
    delta_sal, base_sal = 0, 0.6235
    tasks_num = 0
    if 'semseg' in eval_results:
        scores_logs["val/tasks/semseg/mIoU"] = eval_results['semseg']['mIoU']
        delta_semseg = eval_results['semseg']['mIoU'] - base_semseg
        tasks_num += 1
    if 'normals' in eval_results:
        scores_logs["val/tasks/normals/mean"] = eval_results['normals']['mean']
        scores_logs["val/tasks/normals/rmse"] = eval_results['normals']['rmse']
        scores_logs["val/tasks/normals/mean_v2"] = eval_results['normals']['mean_v2']
        scores_logs["val/tasks/normals/rmse_v2"] = eval_results['normals']['rmse_v2']
        # Lower rmse is better for normals, so the delta is inverted.
        delta_norm = base_norm - eval_results['normals']['rmse']
        tasks_num += 1
    if 'human_parts' in eval_results:
        scores_logs["val/tasks/human_parts/mIoU"] = eval_results['human_parts']['mIoU']
        delta_human_parts = eval_results['human_parts']['mIoU'] - base_human_parts
        tasks_num += 1
    if 'sal' in eval_results:
        scores_logs["val/tasks/sal/maxF"] = eval_results['sal']['maxF']
        scores_logs["val/tasks/sal/Beta maxF"] = eval_results['sal']['Beta maxF']
        scores_logs["val/tasks/sal/mIoU"] = eval_results['sal']['mIoU']
        delta_sal = eval_results['sal']['mIoU'] - base_sal
        tasks_num += 1
    if 'edge' in eval_results:
        # Fixed: the edge loss was logged under "val/tasks/sal/loss".
        scores_logs["val/tasks/edge/loss"] = eval_results['edge']['loss']
    if 'depth' in eval_results:
        scores_logs["val/tasks/depth/rmse"] = eval_results['depth']['rmse']
        scores_logs["val/tasks/depth/log_rmse"] = eval_results['depth']['log_rmse']

    # Guard: tasks_num is 0 when only edge/depth are configured, which
    # previously raised ZeroDivisionError.
    if tasks_num > 0:
        delta_m = (delta_semseg / base_semseg + delta_sal / base_sal + delta_norm / base_norm
                   + delta_human_parts / base_human_parts) / tasks_num
    else:
        delta_m = 0.0
    scores_logs["val/delta-m"] = delta_m

    if wandb_available:
        wandb.log(scores_logs)
    print(f"delta-m:  {delta_m}")

    return eval_results


@torch.no_grad()
def throughput(data_loader, model, logger):
    """Measure inference throughput on the first batch of `data_loader`:
    50 untimed warm-up forward passes followed by 30 timed passes."""
    model.eval()

    for images, _ in data_loader:
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]
        # Warm-up passes (not timed).
        for _ in range(50):
            model(images)
        logger.info(f"throughput averaged with 30 times")
        t0 = time.perf_counter()
        for _ in range(30):
            model(images)
        t1 = time.perf_counter()
        logger.info(
            f"batch_size {batch_size} throughput {30 * batch_size / (t1 - t0)}")
        # Only the first batch is measured.
        return


def get_least_used_gpu():
    """Return the index of the GPU with the most free memory.

    Queries NVML; ties are broken in favor of the higher device index
    (matching the previous descending-sort selection).  NVML is now shut
    down even if a query raises (the original leaked the NVML session on
    any exception between init and shutdown).
    """
    pynvml.nvmlInit()
    try:
        # max over (free_bytes, index) tuples: most free memory wins,
        # and on equal free memory the higher index wins.
        _, best_gpu = max(
            (pynvml.nvmlDeviceGetMemoryInfo(
                pynvml.nvmlDeviceGetHandleByIndex(i)).free, i)
            for i in range(pynvml.nvmlDeviceGetCount())
        )
    finally:
        pynvml.nvmlShutdown()
    return best_gpu


if __name__ == '__main__':
    args, config = parse_option()

    if config.AMP_OPT_LEVEL:
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")

    # Report (but do not use) any torch.distributed environment — this
    # script currently runs single-process; the distributed setup from the
    # upstream Swin repo is disabled.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1

    # Pin all work to the GPU with the most free memory.
    device = torch.device(f'cuda:{get_least_used_gpu()}')
    print(f"device ===============> {device}")
    torch.cuda.set_device(device)

    # Seed everything for reproducibility.
    seed = config.SEED
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Let cuDNN auto-tune the fastest conv algorithms for the fixed input sizes.
    cudnn.benchmark = True

    # Linearly scale the learning rates with the total batch size relative
    # to a reference batch size of 512 (may not be optimal).  Gradient
    # accumulation multiplies the effective batch size, so the rates are
    # scaled up by ACCUMULATION_STEPS as well: with e.g. 4 steps, gradients
    # from 4 mini-batches are summed before a single weight update.
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE / 512.0
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()  # unlock the config to write the scaled rates
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()  # re-lock the config

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT,
                           name=f"{config.MODEL.NAME}")

    logger.info(config.dump())
    logger.info(json.dumps(vars(args)))

    # Initialize wandb unless disabled by flag or unavailable.
    if args.disable_wandb:
        wandb_available = False
        logger.info("Wandb logging disabled.")
    elif wandb_available:
        try:
            # Prefer an explicit API key from the environment when present.
            if not os.getenv("WANDB_API_KEY"):
                wandb.login()
            else:
                wandb.login(key=os.getenv("WANDB_API_KEY"))
            config_name = f"{os.path.basename(os.path.dirname(args.cfg))}/{os.path.basename(args.cfg)}"
            wandb.init(project='MTLoRA', config=config,
                       name=config_name if not args.run_name else args.run_name)
            wandb.config.update({'args': vars(args)})
        except Exception:
            # Fixed: was logger.warnning(...), which raised AttributeError
            # and crashed the run instead of falling back to no logging.
            logger.warning("Could not initialize wandb. Logging is disabled.")
            wandb_available = False

    main(config)
